from tensorflow.python.keras.utils.data_utils import get_file
import os 
import numpy as np
# from google_drive_downloader import GoogleDriveDownloader as gdd

# Dataset source (credit): https://www.microsoft.com/en-us/download/details.aspx?id=54765

# Presumably the image side length in pixels; it is not referenced by the
# code in this file, so it is likely read by importers. TODO confirm.
IMG_SIZE = 128

def _load_split(path, split):
    """Load the ``TIMG_<split>_X.npy`` / ``TIMG_<split>_Y.npy`` pair from ``path``.

    Returns ``(data, labels)``:
    - ``data`` is transposed with axes ``(0, 3, 1, 2)`` (presumably
      NHWC -> NCHW; TODO confirm the on-disk layout).
    - ``labels`` gets a new axis inserted at position 1 and is cast to ``int``.
    """
    data = np.load(os.path.join(path, 'TIMG_%s_X.npy' % split))
    labels = np.load(os.path.join(path, 'TIMG_%s_Y.npy' % split))
    labels = np.expand_dims(labels, axis=1).astype(int)
    data = np.transpose(data, (0, 3, 1, 2))
    return data, labels


def load_data(path, seed=0):
    """Load the TIMG train/test arrays from ``path`` and shuffle the train split.

    Args:
        path: Directory containing ``TIMG_train_X.npy``, ``TIMG_train_Y.npy``,
            ``TIMG_test_X.npy`` and ``TIMG_test_Y.npy``.
        seed: Seed for the training-split shuffle. The default ``0``
            reproduces the ordering this function has always produced.

    Returns:
        ``(train_data, train_label, test_data, test_label)``.
        - Image arrays are transposed with axes ``(0, 3, 1, 2)``.
        - Label arrays get an extra axis at position 1 and are cast to ``int``.
        - Only the training split is shuffled, with the same permutation
          applied to images and labels; the test split keeps its on-disk order.
        - The shapes of both image arrays are printed.
    """
    train_data, train_label = _load_split(path, 'train')

    # Use a private RandomState instead of np.random.seed(): seeding the
    # global generator silently reset NumPy's RNG for the whole process.
    # RandomState(seed).permutation gives the same permutation the global
    # seed did, so the shuffle order is unchanged.
    randids = np.random.RandomState(seed).permutation(train_label.shape[0])
    train_data = train_data[randids]
    train_label = train_label[randids]

    test_data, test_label = _load_split(path, 'test')

    print(train_data.shape)
    print(test_data.shape)
    return train_data, train_label, test_data, test_label


def TIMG(path='/home/dixzhu/source/models-master/research/resnet/data/'):
    """Load the TIMG dataset, casting images to float and labels to int32.

    Args:
        path: Directory holding the ``TIMG_{train,test}_{X,Y}.npy`` files.
            Defaults to the original hard-coded location so existing callers
            keep working; pass another directory to load from elsewhere.

    Returns:
        ``((train_X, train_Y), (test_X, test_Y))``.
        - The ``*_X`` arrays are cast to ``float`` and the ``*_Y`` arrays to
          ``np.int32``.
        - Axis order and shuffling are exactly as returned by ``load_data``.
    """
    train_X, train_Y, test_X, test_Y = load_data(path)

    # Convert data types: float images, int32 labels.
    train_X, train_Y = train_X.astype(float), train_Y.astype(np.int32)
    test_X, test_Y = test_X.astype(float), test_Y.astype(np.int32)

    return (train_X, train_Y), (test_X, test_Y)


